#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os,requests
from libs import dbsql,mime
from requests_toolbelt.multipart.encoder import MultipartEncoder
import re
import traceback
# Base path re-exported from the project's dbsql module (presumably the project root - confirm in libs/dbsql).
PROJECT_PATH = dbsql.PROJECT_PATH
# "decorators" directory directly under PROJECT_PATH.
DECORATORS_PATH = os.path.join(PROJECT_PATH,'decorators')
# Directory that get_params() scans for, and reads, the saved multipart parameter file.
MULTIPARTFILE_PATH = os.path.join(DECORATORS_PATH,"multipartfile")
# ".txt" suffix constant; not referenced by any function visible in this part of the file.
SUF = ".txt"
def get_files(fDir,suf):
    '''
    @des
    Collect the names of the files found under ``fDir``, including every
    sub-directory reached by ``os.walk``.

    When ``suf`` is non-empty (e.g. ".txt"), only files whose extension
    equals ``suf`` are kept and they are returned WITHOUT the extension.
    When ``suf`` is empty, every file is returned with its full file name.
    Directory parts are never included in the returned names.
    '''
    collected = []
    for _root, _dirs, filenames in os.walk(fDir):
        if suf:
            # Keep only the stem of the files carrying the requested suffix.
            collected.extend(
                stem
                for stem, ext in map(os.path.splitext, filenames)
                if ext == suf
            )
        else:
            collected.extend(filenames)
    return collected

def get_headers(sysName):
    '''
    @des
    Return whatever dbsql.mitmex_get_headers_by_system_name yields for
    ``sysName``, always passing "httpmultipart" as the first argument.
    '''
    return dbsql.mitmex_get_headers_by_system_name("httpmultipart",sysName)
def get_url(sysName):
    '''
    @des
    Return whatever dbsql.mitmex_get_url_by_system_name yields for
    ``sysName``, always passing "httpmultipart" as the first argument.
    '''
    return dbsql.mitmex_get_url_by_system_name("httpmultipart",sysName)
def get_params(sysName):
    '''
    @des
    Load the saved multipart parameters for ``sysName``.

    The parameter file name comes from
    dbsql.mitmex_get_params_by_system_name("httpmultipart", sysName). It is
    only opened when a file with that exact name is reported by
    get_files(MULTIPARTFILE_PATH, ""). Only the FIRST line of the file is
    read, and that line is evaluated as a Python expression.

    Returns the evaluated value, or an empty dict when there is no file
    name, no matching file, or the first line is empty.
    '''
    params_name = dbsql.mitmex_get_params_by_system_name("httpmultipart",sysName)
    known_names = get_files(MULTIPARTFILE_PATH,"")
    if not (params_name and known_names):
        return {}
    if params_name not in known_names:
        return {}
    # The path is always built directly under MULTIPARTFILE_PATH, even though
    # get_files() also reports names found in sub-directories.
    with open(os.path.join(MULTIPARTFILE_PATH,params_name),'r') as handle:
        first_line = handle.readline()
    if not first_line:
        return {}
    # NOTE(security): eval() executes whatever the file contains. If these
    # files only ever hold literals, ast.literal_eval would be the safer
    # choice - confirm the file format before switching.
    return eval(first_line)
def get_boundary(content:str):
    '''
    @des
    Extract the ``boundary`` parameter from a Content-Type header value.

    The parameter name is matched case-insensitively. A quoted value
    (boundary="...") is returned without its quotes; an unquoted value ends
    at the next ";" or whitespace, so trailing parameters such as
    "; charset=utf-8" are not swallowed into the boundary.

    Returns "" when ``content`` is empty/None or has no boundary parameter.
    '''
    if not content:
        return ""
    found = re.search(
        r'\bboundary=(?:"([^"]*)"|([^;\s]*))',
        content,
        re.IGNORECASE,
    )
    if found is None:
        return ""
    quoted, bare = found.groups()
    # Exactly one alternative matched; prefer the quoted form when present.
    return quoted if quoted is not None else bare
def format_params(params:dict):
    '''
    @des
    Interactively turn the saved ``params`` into the ``fields`` mapping
    handed to MultipartEncoder.

    Steps, in order:
      1. print every field name;
      2. ask which field carries the uploaded file and what the file is
         called; the MIME type is looked up from the file extension via
         mime.get_mimetype;
      3. store the file field as a (file name, raw content, MIME type)
         tuple and decode every other value from UTF-8 bytes to str;
      4. print the prepared fields (the file content is replaced by a
         placeholder text) and ask for ONE field name and its new value.

    Returns the prepared dict with that one field overwritten. Returns an
    empty dict when ``params`` is empty, when the file field name or the
    MIME type is missing, or when no replacement name/value is entered.
    '''
    if not params:
        return {}
    print("字段名称:")
    for field in params:
        print(field)
    upload_field = input("上传文件的字段名: ")
    upload_name = input("上传的文件名称: ")
    # Extension = text after the last dot of the entered file name.
    mime_type = mime.get_mimetype(upload_name.split(".")[-1])
    if not (upload_field and mime_type):
        return {}
    upload_field = upload_field.strip()
    prepared = {}
    for field, raw in params.items():
        if upload_field == field:
            # File part: (file name, content, content type).
            prepared[field] = (upload_name,raw,mime_type)
        else:
            prepared[field] = str(raw, encoding="utf-8")
    print("参数:")
    for field, shown in prepared.items():
        if field == upload_field:
            # Do not dump the raw file content to the console.
            shown = "二进制内容"
        print(field + " : " + shown)
    target_name = input("输入修改的参数名: " ).strip()
    target_value = input("输入修改的内容: ").strip()
    if not (target_name and target_value):
        return {}
    prepared[target_name] = target_value
    return prepared
def _content_type_of(headers):
    '''
    @des
    Return the Content-Type value stored in ``headers``, or "" when the
    mapping is empty/None or carries no such header. The exact key
    'Content-Type' is tried first; otherwise the header name is matched
    case-insensitively.
    '''
    if not headers:
        return ""
    try:
        return headers['Content-Type'] or ""
    except KeyError:
        pass
    for name, value in headers.items():
        if str(name).lower() == 'content-type':
            return value or ""
    return ""


def get_multipart_encoder_headers_url():
    '''
    @des
    Ask for a system name, load its saved params / headers / url and build
    the MultipartEncoder used to resend the request.

    The encoder reuses the boundary found in the saved Content-Type header.

    Returns (multipart_encoder, headers, url) on success. Returns
    ('', '', '') when the system name is empty, no params are stored, the
    interactive formatting yields nothing, or no boundary can be found. A
    missing Content-Type header also leads to ('', '', '') and no longer
    raises KeyError.
    '''
    failure = ('','','')
    sysName = input("系统名称(url或Ip): ").strip()
    if not sysName:
        return failure
    params = get_params(sysName)
    headers = get_headers(sysName)
    url = get_url(sysName)
    if not params:
        return failure
    tparams = format_params(params)
    # Header names are case-insensitive and the header may be absent, so the
    # lookup must not assume the exact 'Content-Type' key exists.
    boundary = get_boundary(_content_type_of(headers))
    if not (tparams and boundary):
        return failure
    multipart_encoder = MultipartEncoder(
        fields=tparams,
        boundary=boundary
    )
    return multipart_encoder,headers,url

def execute(args):
    '''
    @des
    Entry point: build the multipart request interactively and POST it.

    ``args`` is accepted but not used. On success the response body is
    printed; when no encoder could be built an error message is printed.
    Any exception is reported with a message plus traceback and is not
    re-raised.
    '''
    try:
        encoder, headers, url = get_multipart_encoder_headers_url()
        if not encoder:
            print("数据错误!")
            return
        response = requests.post(url, headers=headers, data=encoder)
        print(response.text)
    except Exception:
        print("程序异常!")
        traceback.print_exc()

                






